# To add a new cell, type '# %%'
# To add a new markdown cell, type '# %% [markdown]'
# %%
import os
import numpy as np

import scipy.misc
from matplotlib import pyplot as plt
import scipy.io as scio
from pathlib import Path

import pandas as pd

from metric_fun import psnr, corr_calc, snr_1d

from sklearn import linear_model
from custom_ksvd_dict_learning import KSVD
from sklearn.decomposition import DictionaryLearning

# %%
import warnings
# Silence library warning spam (scipy/sklearn deprecations) in notebook output.
warnings.filterwarnings("ignore")


# %%
import matplotlib as mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['font.sans-serif']=['SimHei'] # use a CJK-capable font so Chinese labels render
plt.rcParams['axes.unicode_minus']=False # render the minus sign correctly with CJK fonts


# %%
# class KSVD(object):
#     def __init__(self, n_components, max_iter=30, tol=1e-6,
#                  n_nonzero_coefs=None):
#         """
#         稀疏模型Y = DX，Y为样本矩阵，使用KSVD动态更新字典矩阵D和稀疏矩阵X
#         :param n_components: 字典所含原子个数（字典的列数）
#         :param max_iter: 最大迭代次数
#         :param tol: 稀疏表示结果的容差
#         :param n_nonzero_coefs: 稀疏度
#         """
#         self.dictionary = None
#         self.sparsecode = None
#         self.max_iter = max_iter
#         self.tol = tol
#         self.n_components = n_components
#         self.n_nonzero_coefs = n_nonzero_coefs

#     def _initialize(self, y):
#         """
#         初始化字典矩阵
#         """
#         u, s, v = np.linalg.svd(y)
#         self.dictionary = u[:, :self.n_components]

#     def _update_dict(self, y, d, x):
#         """
#         使用KSVD更新字典的过程
#         """
#         for i in range(self.n_components):
#             index = np.nonzero(x[i, :])[0]
#             if len(index) == 0:
#                 continue

#             d[:, i] = 0
#             r = (y - np.dot(d, x))[:, index]
#             u, s, v = np.linalg.svd(r, full_matrices=False)
#             d[:, i] = u[:, 0].T
#             x[i, index] = s[0] * v[0, :]
#         return d, x

#     def fit(self, y):
#         """
#         KSVD迭代过程
#         """
#         self._initialize(y)
#         for i in range(self.max_iter):
#             x = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs)
#             e = np.linalg.norm(y - np.dot(self.dictionary, x))
#             if e < self.tol:
#                 break
#             self._update_dict(y, self.dictionary, x)

#         self.sparsecode = linear_model.orthogonal_mp(self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs)
#         self.components_ = self.dictionary
#         return self


# %%
def read_data_from_mat(mat_path):
    """Load a MATLAB .mat file and return its variable dictionary."""
    return scio.loadmat(mat_path)

def save_np_data(np_array, save_path):
    """Persist *np_array* to *save_path* in NumPy .npy binary format."""
    np.save(save_path, np_array)


# 从mat文件中读取列名
def pre_read_data(mat_path, col_name):
    """Extract column *col_name* from a .mat file and dump it as a 1-D .txt file.

    The text file is written next to the .mat file, with the same base name
    and a '.txt' extension.

    :param mat_path: path to the MATLAB .mat file
    :param col_name: variable name inside the .mat file to extract
    :return: path of the written .txt file
    """
    mat_data = scio.loadmat(mat_path)
    # dirname of the absolute path is already absolute; no second abspath needed.
    np_data_dir_path = os.path.dirname(os.path.abspath(mat_path))

    np_data_name = os.path.basename(mat_path).split('.mat')[0] + '.txt'
    np_data_path = os.path.join(np_data_dir_path, np_data_name)

    print("mat_data_type:", type(mat_data[col_name]))
    print("mat_data_length:", len(mat_data[col_name]))
    print('np_data_path:', np_data_path)
    # loadmat returns a 2-D (N, 1) column; flatten it to shape (N,).
    # (The original pre-allocated np.zeros(...) here and immediately
    # overwrote it — the dead allocation is removed.)
    np_data = mat_data[col_name]
    print('np_data_shape:', np_data.shape)
    np_data = np_data.reshape(len(np_data),)
    print('np_data:\n', np_data)
    print('np_data_shape:', np_data.shape)
    np.savetxt(np_data_path, np_data, encoding='gbk')
    return np_data_path

def load_txtfile(txt_file_path):
    """Load a whitespace-delimited text file into a NumPy array."""
    return np.loadtxt(txt_file_path)



def calc_mean_with_nearest_window(t, window_width):
    """Smooth a 1-D sequence with a position-dependent neighbourhood mean.

    Per-index behaviour:
      * index 0                  -> the sample itself
      * 0 < index < width        -> 3-point centred mean
      * interior                 -> mean over *window_width* samples taken
                                    backward for i <= width//2 and forward after
      * the trailing region      -> 3-point backward mean

    :param t: 1-D array-like of samples
    :param window_width: size of the interior averaging window
    :return: np.ndarray of the same length with the smoothed values
    """
    data_length = len(t)
    half = int(window_width / 2)
    nearest_mean = np.zeros(data_length)
    for idx in range(data_length):
        if idx == 0:
            window = [t[0]]
        elif idx < window_width:
            window = [(t[idx - 1] + t[idx] + t[idx + 1]) / 3]
        elif idx < data_length - window_width:
            # Asymmetric window: look back for the first half, forward after.
            window = [t[idx - i] if i <= half else t[idx + i]
                      for i in range(window_width)]
        else:
            window = [(t[idx - 2] + t[idx - 1] + t[idx]) / 3]
        nearest_mean[idx] = np.average(np.array(window))
    return nearest_mean

# Reference: MATLAB's smooth() function, which calc_mean_with_nearest_window
# approximates (its interior/edge handling differs from MATLAB's exact rule).
# The smoothing window must be odd.
# yy = smooth(y) smooths the data in the column vector y.
# The first few elements of yy are given by
# yy(1) = y(1)
# yy(2) = (y(1) + y(2) + y(3))/3
# yy(3) = (y(1) + y(2) + y(3) + y(4) + y(5))/5
# yy(4) = (y(2) + y(3) + y(4) + y(5) + y(6))/5
# ...
# https://blog.csdn.net/weixin_40532625/article/details/91950668

# %% [markdown]
# # Step 1: pre-read the data (convert .mat columns to .txt files)

# %%

# Clean reference signal (variable 'x' inside the .mat file).
orgin_data_path = "data/x_snr20.mat"
txt_orgin_data_path = pre_read_data(orgin_data_path, 'x')


# %%
# Noisy/distorted signal (variable 'y'), same length as the clean one.
orgin_data_distorted_path = "data/y_snr20.mat"
txt_orgin_data_distorted_path = pre_read_data(orgin_data_distorted_path, 'y')


# %%



# %%
# Ensure the results/ output directory exists next to the data directory.
dir_path = os.path.abspath(
    os.path.dirname(os.path.dirname(txt_orgin_data_distorted_path)))
csv_dir_path = os.path.join(dir_path, 'results/')
results_image_dir_path = csv_dir_path
# makedirs(exist_ok=True) replaces the original check-then-create if/else:
# it is race-free and both original branches assigned the same value anyway.
os.makedirs(csv_dir_path, exist_ok=True)
csv_data_dir_path = csv_dir_path

# %% [markdown]
# # Step 2: load the data

# %%
orgin_data = load_txtfile(txt_orgin_data_path)
orgin_data_with_distorted = load_txtfile(txt_orgin_data_distorted_path)


# %%
# Empirical noise level: std of (noisy - clean).
noise_std = np.std(orgin_data_with_distorted - orgin_data)
noise_std

# %% [markdown]
# ## Whether to pre-smooth the signal with a sliding window

# %%
# Toggle sliding-window smoothing of the noisy signal.
smooth_flag = False  # False | True

# Sliding-window width in samples.
window_width = 7

# Index range [start, stop) of the signal to smooth.
smooth_length =  [0, 2500]          # default : [0, len(orgin_data_with_distorted)]


orgin_data_distorted = None


# %%
if not smooth_flag:  # idiomatic truth test instead of '== False'
    orgin_data_distorted = orgin_data_with_distorted
else:
    # Smooth only the configured segment, then stitch the untouched
    # head and tail back around it.
    segment = orgin_data_with_distorted[smooth_length[0]:smooth_length[1]]
    smoothed_segment = calc_mean_with_nearest_window(segment, window_width)
    head = orgin_data_with_distorted[0:smooth_length[0]]
    tail = orgin_data_with_distorted[smooth_length[1]:]
    orgin_data_distorted = np.concatenate([head, smoothed_segment, tail])


# %%
import datetime
def tid_maker():
    """Return a compact timestamp string formatted as YYYY_MMDD_HHMM."""
    return datetime.datetime.now().strftime('%Y_%m%d_%H%M')


# %%


# %% [markdown]
# # Step 3: reshape the 1-D signals to 2-D

# %%
data_length = len(orgin_data)


# %%
orgin_data


# %%
# The 2500-sample signals are viewed as 50x50 images for dictionary learning.
data_2d_shape = [50,50]

orgin_data_2d = orgin_data.reshape(data_2d_shape[0],data_2d_shape[1])
orgin_data_distorted_2d   = orgin_data_distorted.reshape(data_2d_shape[0],data_2d_shape[1])
# orgin_data_with_distorted_2d   = orgin_data_with_distorted.reshape(data_2d_shape[0],data_2d_shape[1])



# Show the clean and distorted signals side by side as images.
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(orgin_data_2d)
plt.subplot(1, 2, 2)
plt.imshow(orgin_data_distorted_2d)
plt.tight_layout()
plt.show()


# %%
# Learn a 50-atom dictionary (OMP sparse coding, 10 nonzero coefficients)
# separately for the clean and the distorted image.
ksvd = KSVD(n_components = 50, max_iter = 100, tol = 1e-6,  n_nonzero_coefs = 10, transform_algorithm='omp')
orgin_data_reconstruct_model = ksvd.fit(orgin_data_2d)
ksvd_1 = KSVD(n_components = 50, max_iter = 100, tol = 1e-6,  n_nonzero_coefs = 10,transform_algorithm='omp')
orgin_data_distorted_reconstruct_model = ksvd_1.fit(orgin_data_distorted_2d)

dictionary, sparsecode = orgin_data_reconstruct_model.components_, orgin_data_reconstruct_model.sparsecode
dictionary_1, sparsecode_1 = orgin_data_distorted_reconstruct_model.components_, orgin_data_distorted_reconstruct_model.sparsecode


# %%
# Reconstruct both images from their dictionaries and sparse codes and
# compare them with the originals side by side.
A = dictionary.dot(sparsecode)
B = dictionary_1.dot(sparsecode_1)
plt.figure(figsize=(10,10))
plt.subplot(2, 2, 1)
plt.imshow(orgin_data_2d)
plt.title("orgin_data_2d")
plt.subplot(2, 2, 2)
plt.imshow(A)
plt.title("orgin_data_2d_reconstruct")
plt.subplot(2, 2, 3)
plt.imshow(orgin_data_distorted_2d)
plt.title("orgin_data_distorted_2d")
plt.subplot(2, 2, 4)
plt.imshow(B)
plt.title("orgin_data_distorted_2d_reconstruct")
plt.tight_layout()
plt.savefig(os.path.join(results_image_dir_path, 'k-svd重构结果2D效果图_' + tid_maker() + '.png'))
plt.show()


# %%
# Same reconstructions via the @ (matrix product) operator.
reconstruct_orgin_data_2d =   dictionary @ sparsecode
reconstruct_orgin_data_distorted_2d =  dictionary_1  @ sparsecode_1


# %%
# Flatten the 2-D reconstructions back to 1-D signals for plotting.
reconstruct_orgin_data_1d = reconstruct_orgin_data_2d.reshape(len(orgin_data))
reconstruct_orgin_data_distorted_1d = reconstruct_orgin_data_distorted_2d.reshape(len(orgin_data))


# %%
dictionary.shape


# %%
# Overlay the 1-D reconstructions on the original signals.
plt.figure(figsize=(16,8))
plt.subplot(1, 2, 1)
plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d, color = 'blue', label = 'reconstruct_data_1d')
plt.legend()
plt.grid()

plt.subplot(1, 2, 2)
plt.plot(np.arange(len(orgin_data)), orgin_data_distorted, color = 'black', label = 'orgin_data_distorted')
plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_distorted_1d, color = 'red', label = 'reconstruct_data_distorted_1d')
plt.legend()
plt.grid()
plt.tight_layout()
plt.savefig(os.path.join(results_image_dir_path, 'k-svd重构结果1D效果图_' + tid_maker() + '.png'))
plt.show()






#################################################### 执行去噪###########################################

from time import time

# from sklearn.feature_extraction.image import extract_patches_2d
# from sklearn.feature_extraction.image import reconstruct_from_patches_2d
from make_patchs import extract_patches_2d
from make_patchs import reconstruct_from_patches_2d


# %%
# #############################################################################
# Display the distorted image

def show_with_diff(image, reference, title):
    """Display *image* beside its difference from *reference*.

    :param image: 2-D array shown in the left panel
    :param reference: 2-D array the difference is taken against
    :param title: figure suptitle
    """
    plt.figure(figsize=(5, 3.3))
    plt.subplot(1, 2, 1)
    plt.title('Image')
    plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.subplot(1, 2, 2)
    difference = image - reference

    # NOTE(review): the reported value is the L2 (Frobenius) norm of the
    # difference, not a true RMSE (there is no division by the pixel count).
    plt.title('Difference(RMSE) (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))
    plt.imshow(difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr,
               interpolation='nearest')
    plt.xticks(())
    plt.yticks(())
    plt.suptitle(title, size=16)
    plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)
    plt.tight_layout()



# #############################################################################


# %%
# a = np.array([1,2,3,4,5,6]).reshape(-1,1)
# a


# # %%
# b = a.reshape(3,2)
# b

# %% [markdown]
# # 使用重构数据去噪
# %% [markdown]
# class KSVD(n_components, max_iter=30, tol=1e-6,
#                  n_nonzero_coefs=None):
#         """
#         稀疏模型Y = DX，Y为样本矩阵，使用KSVD动态更新字典矩阵D和稀疏矩阵X  
#         :param n_components: 字典所含原子个数（字典的列数）  
#         :param max_iter: 最大迭代次数  
#         :param tol: 稀疏表示结果的容差  
#         :param n_nonzero_coefs: 稀疏度  
#         """



# %%
# #############################################################################
# Extract noisy patches for dictionary training.

# Patch size used for both dictionary training and denoising.
patch_size = (10, 10)
height, width = reconstruct_orgin_data_2d.shape
print('Extracting noisy patches... ')
t0 = time()
data = extract_patches_2d(orgin_data_distorted_2d, patch_size, pointed_step=1)

# NOTE(review): rng is unused in this cell — kept so any downstream
# randomness stays reproducible; confirm before removing.
rng = np.random.RandomState(0)
print('data_final_patchs.shape:',data.shape)
# Flatten each patch to a row vector and subtract the per-feature
# (pixel-position) mean over all patches; it is added back after
# reconstruction.
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))

# (title, sparse-coding algorithm, algorithm kwargs) triples: OMP with an
# increasing number of atoms per patch. Titles double as dict keys below.
transform_algorithms = [
    ('Orthogonal Matching Pursuit\n1 atom', 'omp',
     {'n_nonzero_coefs': 1}),
    ('Orthogonal Matching Pursuit\n2 atoms', 'omp',
     {'n_nonzero_coefs': 2}),
    ('Orthogonal Matching Pursuit\n3 atoms', 'omp',
     {'n_nonzero_coefs': 3}),
    # BUG FIX: the original title read '...\4 atoms' — '\4' is an octal
    # escape (chr(4)), not a newline, so the label carried a control char.
    ('Orthogonal Matching Pursuit\n4 atoms', 'omp',
     {'n_nonzero_coefs': 4}),
    ('Orthogonal Matching Pursuit\n5 atoms', 'omp',
     {'n_nonzero_coefs': 5}),
    ('Orthogonal Matching Pursuit\n8 atoms', 'omp',
     {'n_nonzero_coefs': 8})]


# Learn the denoising dictionary V from the noisy patches.
dico_dict = KSVD(n_components = 50, max_iter = 100, tol = 1e-5,  n_nonzero_coefs = 10, transform_algorithm='omp')

dico_dict_model = dico_dict.fit(data)
V = dico_dict_model.components_


# Training-loss curve recorded by the custom KSVD implementation.
loss = dico_dict_model.loss


# %%
# import numpy as np
# data = [[1,2],[3,4]]
# c = np.mean(data, axis=0)
# c
height, width = orgin_data_distorted_2d.shape
print('Extracting noisy patches... ')
t0 = time()
# data = extract_patches_2d(reconstruct_orgin_data_distorted_2d, patch_size, pointed_step=1)
data = extract_patches_2d(orgin_data_distorted_2d, patch_size, pointed_step=1)
print('data_final_patchs.shape:',data.shape)
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print('done in %.2fs.' % (time() - t0))


reconstructions_data = {}
results_psnr = {}
results_rmse = {}

titles = []
count = 0
# For each OMP setting: re-fit the sparse code on the noisy patches,
# reconstruct from the noisy dictionary V, and score against the clean image.
for title, transform_algorithm, kwargs in transform_algorithms:
    titles.append(title)
    print(title + '...')
    reconstructions_data[title] = reconstruct_orgin_data_2d.copy()
    t0 = time()
    print(transform_algorithm)
    dico_dict.set_params(transform_algorithm=transform_algorithm, **kwargs)
    code_new = dico_dict.fit(data).sparsecode

    # Plot the training loss of this fit.
    loss_new = dico_dict.loss
    plt.figure()
    epoch_new = np.arange(len(loss_new))
    plt.plot(epoch_new, loss_new, label='loss')
    plt.xlabel('iterations')
    plt.ylabel('loss')
    plt.show()


    # Reconstruct the patches with the noisy dictionary.
    patches = np.dot(V, code_new)

    # Restore the per-feature mean removed before coding.
    patches += intercept
    patches = patches.reshape(len(data), *patch_size)
    if transform_algorithm == 'threshold':
        patches -= patches.min()
        patches /= patches.max()
    reconstructions_data[title] = reconstruct_from_patches_2d(
        patches, (height, width))
    dt = time() - t0
    count = count + 1
    print('done in %.2fs.' % dt)
    # PSNR/RMSE are measured against the clean (noise-free) image.
    results_psnr[title], results_rmse[title]  = psnr(reconstructions_data[title], orgin_data_2d)
    
    show_with_diff(reconstructions_data[title], reconstruct_orgin_data_2d,
                   title + ' (time: %.1fs' % dt + ' psnr: %.2f'  % results_psnr[title] + ' rmse: %.4f)'  % results_rmse[title])

plt.show()






# # Save the learned dictionary

# %%
model_save_path = 'dict_model/'
# scio.savemat does not create parent directories; create them if missing.
os.makedirs(model_save_path, exist_ok=True)

dict_mat_name = 'k-svd_read' + '_dict_patchsize_noise_' + str(patch_size[0]) + '.mat'
scio.savemat(os.path.join(model_save_path,dict_mat_name), {'noise_dict_code': V})


# %%
# Re-key the reconstructions by integer index so they can be plotted in order.
reconstructions_data_new = {}
count = 0
for title, transform_algorithm, kwargs in transform_algorithms:
    reconstructions_data_new[count] = reconstructions_data[title]
    count = count + 1


# %%
np_array_title = np.array(titles)


# %%
titles


# %%
# BUG FIX: the original '(len(titles) + 2) / 2' produced a float (and would
# truncate oddly for odd counts); plt.subplot requires an integer row count.
# Ceil-divide: 2 header panels + one panel per title, two per row.
row_nums = (len(titles) + 3) // 2
row_nums


# %%
# Grid of 2-D images: originals on the first row, one reconstruction per
# algorithm below. (row_nums must be an integer for plt.subplot.)
plt.figure(figsize=(10,5 * row_nums))
plt.subplot(row_nums, 2, 1)
plt.imshow(orgin_data_2d)
plt.title("orgin_data_2d")
plt.subplot(row_nums, 2, 2)
plt.imshow(orgin_data_distorted_2d)
plt.title("orgin_data_distorted_2d")

for i in range(len(titles)):
    plt.subplot(row_nums, 2, 3 + i)
    plt.imshow(reconstructions_data_new[i])
    title = titles[i]
    plt.title(np_array_title[i] + ' psnr: %.2fdb'  % results_psnr[title] + ' rmse: %.4f'  % results_rmse[title])
plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
plt.savefig(os.path.join(results_image_dir_path, 'k-svd去噪结果2D效果图_' + tid_maker() + '.png'))
plt.show()


# %%
reconstruct_orgin_data_2d_denoise_list = []

for i in range(len(titles)):
    reconstruct_orgin_data_2d_denoise_list.append(reconstructions_data_new[i])
reconstruct_orgin_data_2d_denoise = np.array(reconstruct_orgin_data_2d_denoise_list)


reconstruct_orgin_data_1d_denoise_list = []
count_id = 0
for title in titles:
    reconstruct_orgin_data_1d_denoise_list.append(reconstruct_orgin_data_2d_denoise[count_id].reshape(len(orgin_data_distorted)))    
    count_id = count_id + 1
reconstruct_orgin_data_1d_denoise = np.array(reconstruct_orgin_data_1d_denoise_list)




# orgin_data_distorted.shape,reconstruct_orgin_data_1d_denoise[0].shape


resdual_list = []
for i in range(len(titles)):  
    resdual_list.append(reconstruct_orgin_data_1d_denoise[i] - orgin_data)

resdual_array = np.array(resdual_list)



# resdual_array.shape

# %% [markdown]
# ## 计算相关系数

# %%
# Correlation (vs. the clean signal) and SNR for each denoised curve.
curves_num = len(reconstruct_orgin_data_1d_denoise)
results_corr = np.zeros(curves_num)
denoise_snr = np.zeros(curves_num)
for idx in range(curves_num):
    results_corr[idx] = corr_calc(reconstruct_orgin_data_1d_denoise[idx], orgin_data)
    denoise_snr[idx] = snr_1d(reconstruct_orgin_data_1d_denoise[idx], resdual_array[idx])


# Echo the single-line labels used in the legends.
for idx in range(len(titles)):
    line_label = titles[idx].replace("\n", " ")
    print(line_label)


# Summary figure: 1-D reconstructions plus per-algorithm denoising results.
plt.figure(figsize=(18,8 * row_nums))
plt.subplot(row_nums, 2, 1)

plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d, color = 'red', label = 'reconstruct_data_1d')
plt.legend()
plt.grid()

plt.subplot( row_nums, 2, 2)
plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'green', label = 'orgin_data_with_distorted')
plt.plot(np.arange(len(orgin_data)), orgin_data_distorted, color = 'black', label = 'orgin_data_distorted_smooth')
plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_distorted_1d, color = 'red', label = 'reconstruct_data_distorted_1d')
plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
plt.legend()
plt.grid()

for i in range(len(titles)):
    plt.subplot(row_nums, 2, 3+i)
    plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'black', label = 'orgin_data_with_distorted')
    line_label = titles[i].replace("\n", " ")
    plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d_denoise[i], color = 'red', label =line_label)
    plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
    plt.plot(np.arange(len(orgin_data)), resdual_array[i], color = 'blue', label = 'resdual')
    # BUG FIX: the original indexed results_psnr/results_rmse with the stale
    # loop variable `title` left over from the denoising loop, so every
    # subplot reported the LAST algorithm's PSNR/RMSE. Index with titles[i].
    plt.title('R: %.4f' %results_corr[i] +  ' PSNR: %.2fdB'  %results_psnr[titles[i]]  + ' RMSE: %.4f'  %results_rmse[titles[i]] + ' SNR %.2f' %denoise_snr[i])
    plt.legend()
    plt.grid()
plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
plt.savefig(os.path.join(results_image_dir_path, 'k-svd重构去噪结果' + '_patchsize_' + str(patch_size) +  tid_maker() + '.png'))
plt.show()


###############################第二次去噪##################################################
# 第一次训练完无噪声的字典之后，用带噪声的提取patch去重建，使用下面第三行开始的
# 第二次迭代使用下面一行

# patch_size = (8, 8)

# reconstruct_orgin_data_distorted_2d = reconstruct_orgin_data_1d_denoise[0].reshape(50,50)
# data = extract_patches_2d(reconstruct_orgin_data_distorted_2d, patch_size,pointed_step=1)
# data = data.reshape(data.shape[0], -1)
# intercept = np.mean(data, axis=0)
# data -= intercept



# reconstructions_data = {}
# results_psnr = {}
# results_rmse = {}

# titles = []
# count = 0
# for title, transform_algorithm, kwargs in transform_algorithms:
#     titles.append(title)
#     print(title + '...')
#     reconstructions_data[title] = reconstruct_orgin_data_2d.copy()
#     t0 = time()
#     print(transform_algorithm)
#     dico_ksvd.set_params(transform_algorithm=transform_algorithm, **kwargs)
#     # code = dico_ksvd.transform(data)
#     code = dico_ksvd.fit(data).sparsecode
#     print(code.shape)
#     # print(orgin_V.shape)
#     # patches = np.dot(code, V)  # modified——>
#     # 有噪声字典
#     patches = np.dot(V, code)
#     # 无噪声字典
#     # patches =  np.dot(nonoise_V, code)  # orgin_V.dot(code)

#     patches += intercept
#     patches = patches.reshape(len(data), *patch_size)
#     if transform_algorithm == 'threshold':
#         patches -= patches.min()
#         patches /= patches.max()
#     reconstructions_data[title] = reconstruct_from_patches_2d(
#         patches, (height, width))
#     dt = time() - t0
#     count = count + 1
#     print('done in %.2fs.' % dt)
#     # results_psnr[title], results_rmse[title]  = psnr(reconstructions_data[title], orgin_data_distorted_2d)
#     results_psnr[title], results_rmse[title]  = psnr(reconstructions_data[title], orgin_data_2d)
    
#     show_with_diff(reconstructions_data[title], reconstruct_orgin_data_2d,  # orgin_data_distorted_2d
#                    title + ' (time: %.1fs' % dt + ' psnr: %.2f'  % results_psnr[title] + ' rmse: %.4f)'  % results_rmse[title])

# plt.show()



# reconstructions_data_new = {}
# count = 0
# for title, transform_algorithm, kwargs in transform_algorithms:
#     reconstructions_data_new[count]  = reconstructions_data[title]
#     count = count + 1


# np_array_title = np.array(titles)

# row_nums = (len(titles) + 2) / 2


# plt.figure(figsize=(10,5 * row_nums))
# plt.subplot(row_nums, 2, 1)
# plt.imshow(orgin_data_2d)
# plt.title("orgin_data_2d")
# plt.subplot(row_nums, 2, 2)
# plt.imshow(orgin_data_distorted_2d)
# plt.title("orgin_data_distorted_2d")

# for i in range(len(titles)):
#     plt.subplot(row_nums, 2, 3 + i)
#     plt.imshow(reconstructions_data_new[i])
#     title = titles[i]
#     plt.title(np_array_title[i] + ' psnr: %.2fdb'  % results_psnr[title] + ' rmse: %.4f'  % results_rmse[title])
# plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
# plt.savefig(os.path.join(results_image_dir_path, 'k-svd第二次去噪结果2D效果图_' + tid_maker() + '.png'))
# plt.show()


# reconstruct_orgin_data_2d_denoise_list = []

# for i in range(len(titles)):
#     reconstruct_orgin_data_2d_denoise_list.append(reconstructions_data_new[i])
# reconstruct_orgin_data_2d_denoise = np.array(reconstruct_orgin_data_2d_denoise_list)



# reconstruct_orgin_data_1d_denoise_list = []
# count_id = 0
# for title in titles:
#     reconstruct_orgin_data_1d_denoise_list.append(reconstruct_orgin_data_2d_denoise[count_id].reshape(len(orgin_data_distorted)))    
#     count_id = count_id + 1
# reconstruct_orgin_data_1d_denoise = np.array(reconstruct_orgin_data_1d_denoise_list)



# orgin_data_distorted.shape,reconstruct_orgin_data_1d_denoise[0].shape


# resdual_list = []
# for i in range(len(titles)):  
#     resdual_list.append(reconstruct_orgin_data_1d_denoise[i] - orgin_data)

# resdual_array = np.array(resdual_list)


# resdual_array.shape

# # ## 计算相关系数

# curves_num = len(reconstruct_orgin_data_1d_denoise)
# results_corr = np.zeros(curves_num)
# denoise_snr =np.zeros(curves_num)
# count_id = 0
# for i in range(curves_num):
#     results_corr[i] = corr_calc(reconstruct_orgin_data_1d_denoise[count_id],orgin_data)
#     denoise_snr[i] = snr_1d(reconstruct_orgin_data_1d_denoise[count_id],resdual_array[i])
#     count_id = count_id + 1



# for i in range(len(titles)):
#     # print(titles[i])
#     line_label = titles[i].replace("\n", " ")
#     print(line_label)


# plt.figure(figsize=(18,8 * row_nums))
# plt.subplot(row_nums, 2, 1)

# plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
# plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d, color = 'red', label = 'reconstruct_data_1d')
# plt.legend()
# plt.grid()

# plt.subplot( row_nums, 2, 2)
# plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'green', label = 'orgin_data_with_distorted')
# plt.plot(np.arange(len(orgin_data)), orgin_data_distorted, color = 'black', label = 'orgin_data_distorted_smooth')
# plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_distorted_1d, color = 'red', label = 'reconstruct_data_distorted_1d')
# plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
# plt.legend()
# plt.grid()

# # resdual_array
# for i in range(len(titles)):
#     plt.subplot(row_nums, 2, 3+i)
#     plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'black', label = 'orgin_data_with_distorted')
#     line_label = titles[i].replace("\n", " ")
#     plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d_denoise[i], color = 'red', label =line_label)
#     plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
#     plt.plot(np.arange(len(orgin_data)), resdual_array[i], color = 'blue', label = 'resdual')
#     plt.title('R: %.4f' %results_corr[i] +  ' PSNR: %.2fdB'  %results_psnr[title]  + ' RMSE: %.4f'  %results_rmse[title] + ' SNR %.2f' %denoise_snr[i])
#     plt.legend()
#     plt.grid()
# plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
# plt.savefig(os.path.join(results_image_dir_path, 'k-svd重构第二次去噪结果' + '_patchsize_' + str(patch_size) +  tid_maker() + '.png'))
# plt.show()

# ########################################第三次去噪###############################################
# reconstruct_orgin_data_distorted_2d = reconstruct_orgin_data_1d_denoise[0].reshape(50,50)
# data = extract_patches_2d(reconstruct_orgin_data_distorted_2d, patch_size,pointed_step=1)
# data = data.reshape(data.shape[0], -1)
# intercept = np.mean(data, axis=0)
# data -= intercept



# reconstructions_data = {}
# results_psnr = {}
# results_rmse = {}

# titles = []
# count = 0
# for title, transform_algorithm, kwargs in transform_algorithms:
#     titles.append(title)
#     print(title + '...')
#     reconstructions_data[title] = reconstruct_orgin_data_2d.copy()
#     t0 = time()
#     print(transform_algorithm)
#     dico_ksvd.set_params(transform_algorithm=transform_algorithm, **kwargs)
#     # code = dico_ksvd.transform(data)
#     code = dico_ksvd.fit(data).sparsecode
#     print(code.shape)
#     # print(orgin_V.shape)
#     # patches = np.dot(code, V)  # modified——>
#     # 有噪声字典
#     patches = np.dot(V, code)
#     # 无噪声字典
#     # patches =  np.dot(nonoise_V, code)  # orgin_V.dot(code)

#     patches += intercept
#     patches = patches.reshape(len(data), *patch_size)
#     if transform_algorithm == 'threshold':
#         patches -= patches.min()
#         patches /= patches.max()
#     reconstructions_data[title] = reconstruct_from_patches_2d(
#         patches, (height, width))
#     dt = time() - t0
#     count = count + 1
#     print('done in %.2fs.' % dt)
#     # results_psnr[title], results_rmse[title]  = psnr(reconstructions_data[title], orgin_data_distorted_2d)
#     results_psnr[title], results_rmse[title]  = psnr(reconstructions_data[title], orgin_data_2d)
    
#     show_with_diff(reconstructions_data[title], reconstruct_orgin_data_2d,  # orgin_data_distorted_2d
#                    title + ' (time: %.1fs' % dt + ' psnr: %.2f'  % results_psnr[title] + ' rmse: %.4f)'  % results_rmse[title])

# plt.show()



# reconstructions_data_new = {}
# count = 0
# for title, transform_algorithm, kwargs in transform_algorithms:
#     reconstructions_data_new[count]  = reconstructions_data[title]
#     count = count + 1


# np_array_title = np.array(titles)

# row_nums = (len(titles) + 2) / 2


# plt.figure(figsize=(10,5 * row_nums))
# plt.subplot(row_nums, 2, 1)
# plt.imshow(orgin_data_2d)
# plt.title("orgin_data_2d")
# plt.subplot(row_nums, 2, 2)
# plt.imshow(orgin_data_distorted_2d)
# plt.title("orgin_data_distorted_2d")

# for i in range(len(titles)):
#     plt.subplot(row_nums, 2, 3 + i)
#     plt.imshow(reconstructions_data_new[i])
#     title = titles[i]
#     plt.title(np_array_title[i] + ' psnr: %.2fdb'  % results_psnr[title] + ' rmse: %.4f'  % results_rmse[title])
# plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
# plt.savefig(os.path.join(results_image_dir_path, 'k-svd第三次去噪结果2D效果图_' + tid_maker() + '.png'))
# plt.show()


# reconstruct_orgin_data_2d_denoise_list = []

# for i in range(len(titles)):
#     reconstruct_orgin_data_2d_denoise_list.append(reconstructions_data_new[i])
# reconstruct_orgin_data_2d_denoise = np.array(reconstruct_orgin_data_2d_denoise_list)



# reconstruct_orgin_data_1d_denoise_list = []
# count_id = 0
# for title in titles:
#     reconstruct_orgin_data_1d_denoise_list.append(reconstruct_orgin_data_2d_denoise[count_id].reshape(len(orgin_data_distorted)))    
#     count_id = count_id + 1
# reconstruct_orgin_data_1d_denoise = np.array(reconstruct_orgin_data_1d_denoise_list)



# orgin_data_distorted.shape,reconstruct_orgin_data_1d_denoise[0].shape


# resdual_list = []
# for i in range(len(titles)):  
#     resdual_list.append(reconstruct_orgin_data_1d_denoise[i] - orgin_data)

# resdual_array = np.array(resdual_list)


# resdual_array.shape

# # ## 计算相关系数

# curves_num = len(reconstruct_orgin_data_1d_denoise)
# results_corr = np.zeros(curves_num)
# denoise_snr =np.zeros(curves_num)
# count_id = 0
# for i in range(curves_num):
#     results_corr[i] = corr_calc(reconstruct_orgin_data_1d_denoise[count_id],orgin_data)
#     denoise_snr[i] = snr_1d(reconstruct_orgin_data_1d_denoise[count_id],resdual_array[i])
#     count_id = count_id + 1



# for i in range(len(titles)):
#     # print(titles[i])
#     line_label = titles[i].replace("\n", " ")
#     print(line_label)


# plt.figure(figsize=(18,8 * row_nums))
# plt.subplot(row_nums, 2, 1)

# plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
# plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d, color = 'red', label = 'reconstruct_data_1d')
# plt.legend()
# plt.grid()

# plt.subplot( row_nums, 2, 2)
# plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'green', label = 'orgin_data_with_distorted')
# plt.plot(np.arange(len(orgin_data)), orgin_data_distorted, color = 'black', label = 'orgin_data_distorted_smooth')
# plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_distorted_1d, color = 'red', label = 'reconstruct_data_distorted_1d')
# plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
# plt.legend()
# plt.grid()

# # resdual_array
# for i in range(len(titles)):
#     plt.subplot(row_nums, 2, 3+i)
#     plt.plot(np.arange(len(orgin_data)), orgin_data_with_distorted, color = 'black', label = 'orgin_data_with_distorted')
#     line_label = titles[i].replace("\n", " ")
#     plt.plot(np.arange(len(orgin_data)), reconstruct_orgin_data_1d_denoise[i], color = 'red', label =line_label)
#     plt.plot(np.arange(len(orgin_data)), orgin_data, color = 'orange', label = 'orgin_data')
#     plt.plot(np.arange(len(orgin_data)), resdual_array[i], color = 'blue', label = 'resdual')
#     plt.title('R: %.4f' %results_corr[i] +  ' PSNR: %.2fdB'  %results_psnr[title]  + ' RMSE: %.4f'  %results_rmse[title] + ' SNR %.2f' %denoise_snr[i])
#     plt.legend()
#     plt.grid()
# plt.tight_layout(pad = 1, w_pad = 1, h_pad=1)
# plt.savefig(os.path.join(results_image_dir_path, 'k-svd重构第三次去噪结果' + '_patchsize_' + str(patch_size) +  tid_maker() + '.png'))
# plt.show()

#################################################################################

# Write Results: clean signal, distorted signal, and each denoised curve
# with its residual, as columns of a single CSV file.
frames = [
    pd.DataFrame(orgin_data, columns=["orgin_data"]),
    pd.DataFrame(orgin_data_distorted, columns=["orgin_data_distorted"]),
]
for i in range(len(titles)):
    line_label = titles[i].replace("\n", " ")
    frames.append(pd.DataFrame(reconstruct_orgin_data_1d_denoise[i], columns=[line_label]))
    frames.append(pd.DataFrame(resdual_array[i], columns=[line_label + '_resdual']))
# A single concat avoids the quadratic copying of calling pd.concat inside
# the loop; the column order is identical to the original per-loop concats.
pd_data = pd.concat(frames, axis=1)


csv_data_name = os.path.basename(txt_orgin_data_distorted_path).split(".txt")[0] + '_patchsize_' + str(patch_size) + '_denoise_' + tid_maker() + '.csv'
csv_data_path = os.path.join(csv_data_dir_path, csv_data_name)
csv_data_path

# index=False (rather than the original index=None) is the documented way
# to omit the index column.
pd_data.to_csv(csv_data_path, sep=',',  mode='w',float_format='%.4f',index=False, header=True)





